{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0b850cd9",
   "metadata": {},
   "source": [
    "# Lesson 1 - Text Generation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "130b83d4",
   "metadata": {},
   "source": [
    "Welcome to Lesson 1.\n",
    "\n",
    "To access the `requirements.txt` file, go to `File` and click on `Open`.\n",
    "\n",
    "In this lesson, we'll cover the following:\n",
    "\n",
    "1. How to load an LLM from HuggingFace?\n",
    "2. How to generate a token from the model output tensors?\n",
    "3. Prefill and decode: optimizing token generation over multiple steps\n",
    "\n",
    "I hope you enjoy this course!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "355d0908-936e-411e-b7ab-0bad3c0d8bba",
   "metadata": {},
   "source": [
    "### Import required packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d63d6a13-d807-4287-ac43-82c1a760a963",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import time\n",
    "import torch\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c49df59a",
   "metadata": {},
   "source": [
    "## 1. How to load an LLM from HuggingFace?"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6df7749b-2596-460a-b5d6-69f57dd804f4",
   "metadata": {},
   "source": [
    "- Load the LLM\n",
    "- You'll use GPT2 throughout this course"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "921474e8-e424-4a59-8a3c-bcea49690133",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "model_name = \"./models/gpt2\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "82ce24db-94ce-4272-9e30-7d3b2b8ff595",
   "metadata": {},
   "source": [
    "- Examine the model's architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2412d3e-7ee2-4362-adb2-2f434d6f2416",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "print(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "64648331-8f55-4150-929e-4e7a3a76e305",
   "metadata": {},
   "source": [
    "## 2. How to generate a token from the model output tensors?\n",
    "\n",
    "### Text Generation\n",
    "- Start by tokenizing the input prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c94d8b4b-13ea-42ad-917b-54974c9cdbd0",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "prompt = \"The quick brown fox jumped over the\"\n",
    "inputs = tokenizer(prompt, return_tensors=\"pt\")\n",
    "inputs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae14eb9a-2371-4705-ab95-f9c98a4b703c",
   "metadata": {},
   "source": [
    "- Pass the inputs to the model and retrieve the logits to find the most likely next token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "794b0abd-1940-4cb7-902e-9f9d83ec52bd",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    outputs = model(**inputs)\n",
    "\n",
    "logits = outputs.logits\n",
    "print(logits.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58ae92b3-200e-470b-b107-559ee4b4e503",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "last_logits = logits[0, -1, :]\n",
    "next_token_id = last_logits.argmax()\n",
    "next_token_id"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36f5728d-b607-448c-a4cc-9b9239c99134",
   "metadata": {},
   "source": [
    "- Decode the most likely token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82e856ac-805b-4e47-b0e6-aca4580923be",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "tokenizer.decode(next_token_id)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "924cda2a-8b7c-480f-9c7b-6d19c1630765",
   "metadata": {},
   "source": [
    "- Print the 10 most likely next words"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5dc39fa7-57c9-4ad2-a18f-58e64f380f9a",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "top_k = torch.topk(last_logits, k=10)\n",
    "tokens = [tokenizer.decode(tk) for tk in top_k.indices]\n",
    "tokens"
   ]
  },
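  {
   "cell_type": "markdown",
   "id": "7a1e4c02",
   "metadata": {},
   "source": [
    "- Logits are unnormalized scores; a softmax turns them into probabilities without changing their ranking\n",
    "- The cell below is a minimal, dependency-free sketch of that idea. The vocabulary and logit values are made up for illustration and are not GPT-2 outputs\n",
    "- In this notebook, the same normalization could be applied to the real scores with `torch.softmax(last_logits, dim=-1)`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a1e4c03",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# Hypothetical toy vocabulary and logits, for illustration only\n",
    "vocab = [\" fence\", \" dog\", \" wall\", \" moon\"]\n",
    "toy_logits = [2.0, 1.0, 0.5, -1.0]\n",
    "\n",
    "def softmax(scores):\n",
    "    # Subtract the max before exponentiating for numerical stability\n",
    "    m = max(scores)\n",
    "    exps = [math.exp(s - m) for s in scores]\n",
    "    total = sum(exps)\n",
    "    return [e / total for e in exps]\n",
    "\n",
    "probs = softmax(toy_logits)\n",
    "\n",
    "# Rank tokens by probability; the order matches ranking by raw logit\n",
    "ranked = sorted(zip(vocab, probs), key=lambda pair: pair[1], reverse=True)\n",
    "for token, p in ranked[:3]:\n",
    "    print(f\"{token!r}: {p:.3f}\")"
   ]
  },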
  {
   "cell_type": "markdown",
   "id": "c8991969-6a3e-487b-abb9-5381e7bfbcd4",
   "metadata": {},
   "source": [
    "- Concatenate the input and most likely tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8aabe35-efbe-439c-9284-1407fc0b2cab",
   "metadata": {
    "height": 182
   },
   "outputs": [],
   "source": [
    "next_inputs = {\n",
    "    \"input_ids\": torch.cat(\n",
    "        [inputs[\"input_ids\"], next_token_id.reshape((1, 1))],\n",
    "        dim=1\n",
    "    ),\n",
    "    \"attention_mask\": torch.cat(\n",
    "        [inputs[\"attention_mask\"], torch.tensor([[1]])],\n",
    "        dim=1\n",
    "    ),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "449af38e-14b8-4f5d-a972-db44f94000b1",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "print(next_inputs[\"input_ids\"],\n",
    "      next_inputs[\"input_ids\"].shape)\n",
    "print(next_inputs[\"attention_mask\"],\n",
    "      next_inputs[\"attention_mask\"].shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c74fb045",
   "metadata": {},
   "source": [
    "## 3. Prefill and decode: optimizing token generation over multiple steps"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3681d8a1-4e14-4d99-9f78-f122c5fc86ad",
   "metadata": {},
   "source": [
    "### Text generation helper function\n",
    "- The following helper function generates the next tokens given a set of input tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a4605d3-1506-4b80-a6b9-b2e8d5ccea72",
   "metadata": {
    "height": 148
   },
   "outputs": [],
   "source": [
    "def generate_token(inputs):\n",
    "    with torch.no_grad():\n",
    "        outputs = model(**inputs)\n",
    "\n",
    "    logits = outputs.logits\n",
    "    last_logits = logits[0, -1, :]\n",
    "    next_token_id = last_logits.argmax()\n",
    "    return next_token_id"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f65c82bc-e667-41b5-b0b5-b1a222b22cf0",
   "metadata": {},
   "source": [
    "- Use the helper function to generate multiple tokens in a loop\n",
    "- Track the time it takes to generate each token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8312f1a-9a24-4206-8878-022a4db7d5ed",
   "metadata": {
    "height": 403
   },
   "outputs": [],
   "source": [
    "generated_tokens = []\n",
    "next_inputs = inputs\n",
    "durations_s = []\n",
    "for _ in range(10):\n",
    "    t0 = time.time()\n",
    "    next_token_id = generate_token(next_inputs)\n",
    "    durations_s += [time.time() - t0]\n",
    "    \n",
    "    next_inputs = {\n",
    "        \"input_ids\": torch.cat(\n",
    "            [next_inputs[\"input_ids\"], next_token_id.reshape((1, 1))],\n",
    "            dim=1),\n",
    "        \"attention_mask\": torch.cat(\n",
    "            [next_inputs[\"attention_mask\"], torch.tensor([[1]])],\n",
    "            dim=1),\n",
    "    }\n",
    "    \n",
    "    next_token = tokenizer.decode(next_token_id)\n",
    "    generated_tokens.append(next_token)\n",
    "\n",
    "print(f\"{sum(durations_s)} s\")\n",
    "print(generated_tokens)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e42e3ef-b906-4380-b44e-9aa10ea5da34",
   "metadata": {},
   "source": [
    "- Plot token generation time\n",
    "- The x-axis here is the token number\n",
    "- The y-axis is the time to generate a token in millisenconds (ms)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40d97f7b",
   "metadata": {},
   "source": [
    "**Note**: Your plot may vary slightly from the one shown in the video, yet it will exhibit a similar pattern."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec841dd0-5be5-4da7-9a03-4af11449204e",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "plt.plot(durations_s)\n",
    "plt.show()"
   ]
  },
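  {
   "cell_type": "markdown",
   "id": "7a1e4c04",
   "metadata": {},
   "source": [
    "- The loop above feeds the whole sequence back into the model at every step, so the number of token positions processed grows with each generated token\n",
    "- The cell below is a small sketch of that arithmetic. It assumes a 7-token prompt and 10 generation steps; the real prompt length depends on the tokenizer, so check `inputs[\"input_ids\"].shape`\n",
    "- For comparison, it also counts the positions processed if only the newest token were fed after the first step, which is the situation the KV-caching section below sets up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a1e4c05",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assumed values for illustration: a 7-token prompt and 10 generated tokens\n",
    "prompt_len = 7\n",
    "num_steps = 10\n",
    "\n",
    "# Step i (0-based) feeds the prompt plus the i tokens generated so far\n",
    "positions_per_step = [prompt_len + i for i in range(num_steps)]\n",
    "total_positions = sum(positions_per_step)\n",
    "\n",
    "# If only the newest token were fed after the first step:\n",
    "# the full prompt once, then one position per remaining step\n",
    "incremental_positions = prompt_len + (num_steps - 1)\n",
    "\n",
    "print(positions_per_step)\n",
    "print(total_positions, incremental_positions)"
   ]
  },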
  {
   "cell_type": "markdown",
   "id": "112221a8-0172-43a8-a6b8-b1fd6aa0fae1",
   "metadata": {},
   "source": [
    "### Speeding up text generation with KV-caching\n",
    "KV-caching is a technique to speed up token generation by storing some of the tensors in the attention head for use in subsequent generation steps\n",
    "- Modify the generate helper function to return the next token and the key/value tensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3f6fab4-c3ca-4efa-9533-2915667e5058",
   "metadata": {
    "height": 148
   },
   "outputs": [],
   "source": [
    "def generate_token_with_past(inputs):\n",
    "    with torch.no_grad():\n",
    "        outputs = model(**inputs)\n",
    "\n",
    "    logits = outputs.logits\n",
    "    last_logits = logits[0, -1, :]\n",
    "    next_token_id = last_logits.argmax()\n",
    "    return next_token_id, outputs.past_key_values"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae66de3f-2b2b-47fd-bc0a-ad1aac807cb8",
   "metadata": {},
   "source": [
    "- Generate 10 tokens using the updated helper function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98bb1897",
   "metadata": {
    "height": 386
   },
   "outputs": [],
   "source": [
    "generated_tokens = []\n",
    "next_inputs = inputs\n",
    "durations_cached_s = []\n",
    "for _ in range(10):\n",
    "    t0 = time.time()\n",
    "    next_token_id, past_key_values = \\\n",
    "        generate_token_with_past(next_inputs)\n",
    "    durations_cached_s += [time.time() - t0]\n",
    "    \n",
    "    next_inputs = {\n",
    "        \"input_ids\": next_token_id.reshape((1, 1)),\n",
    "        \"attention_mask\": torch.cat(\n",
    "            [next_inputs[\"attention_mask\"], torch.tensor([[1]])],\n",
    "            dim=1),\n",
    "        \"past_key_values\": past_key_values,\n",
    "    }\n",
    "    \n",
    "    next_token = tokenizer.decode(next_token_id)\n",
    "    generated_tokens.append(next_token)\n",
    "\n",
    "print(f\"{sum(durations_cached_s)} s\")\n",
    "print(generated_tokens)"
   ]
  },
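  {
   "cell_type": "markdown",
   "id": "7a1e4c06",
   "metadata": {},
   "source": [
    "- Why is it safe to feed only the newest token? The keys and values of earlier tokens do not change, so only the new token's projections need to be computed\n",
    "- The cell below is a toy sketch of this for a single attention head in a single layer, in plain Python. The projection matrices and embeddings are hypothetical values chosen for illustration; this is not GPT-2's implementation\n",
    "- At every step it checks that attending over the cached keys/values gives the same result as recomputing them for the whole sequence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a1e4c07",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def matvec(matrix, vec):\n",
    "    # Multiply a matrix (list of rows) by a vector\n",
    "    return [sum(w * x for w, x in zip(row, vec)) for row in matrix]\n",
    "\n",
    "def attend(query, keys, values):\n",
    "    # Scaled dot-product attention of one query over all keys/values\n",
    "    scale = math.sqrt(len(query))\n",
    "    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]\n",
    "    m = max(scores)\n",
    "    weights = [math.exp(s - m) for s in scores]\n",
    "    total = sum(weights)\n",
    "    dim = len(values[0])\n",
    "    return [sum(w * v[d] for w, v in zip(weights, values)) / total\n",
    "            for d in range(dim)]\n",
    "\n",
    "# Hypothetical 2-d projection matrices and token embeddings\n",
    "W_q = [[1.0, 0.0], [0.5, 1.0]]\n",
    "W_k = [[0.0, 1.0], [1.0, 0.5]]\n",
    "W_v = [[1.0, 1.0], [0.0, 2.0]]\n",
    "embeddings = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -1.0]]\n",
    "\n",
    "def full_recompute(tokens):\n",
    "    # Recompute keys and values for every token, then attend from the last one\n",
    "    keys = [matvec(W_k, t) for t in tokens]\n",
    "    values = [matvec(W_v, t) for t in tokens]\n",
    "    return attend(matvec(W_q, tokens[-1]), keys, values)\n",
    "\n",
    "# Cached variant: project only the newest token and append to the cache\n",
    "cached_keys, cached_values = [], []\n",
    "for step, token in enumerate(embeddings, start=1):\n",
    "    cached_keys.append(matvec(W_k, token))\n",
    "    cached_values.append(matvec(W_v, token))\n",
    "    cached_out = attend(matvec(W_q, token), cached_keys, cached_values)\n",
    "    full_out = full_recompute(embeddings[:step])\n",
    "    assert all(math.isclose(a, b) for a, b in zip(cached_out, full_out))\n",
    "\n",
    "print(\"cached and recomputed outputs match at every step\")"
   ]
  },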
  {
   "cell_type": "markdown",
   "id": "c8ea5ded-ad4a-42c2-beb8-5cbf5ab9c4e2",
   "metadata": {},
   "source": [
    "- Compare the execution time for the KV-cache function with the original helper function"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a632eaf9",
   "metadata": {},
   "source": [
    "**Note**: Your plot may vary slightly from the one shown in the video, yet it will exhibit a similar pattern."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba249a24-42b0-4ce6-881c-00b1dcbb4212",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "plt.plot(durations_s)\n",
    "plt.plot(durations_cached_s)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
